from scipy.optimize import curve_fit
from scipy.special import factorial
from scipy.stats import poisson
import numpy as np
import scipy.stats as ss
import matplotlib.pyplot as plt
from scipy.stats import expon
from exp_mixture_model import EMM
import matplotlib.transforms as mtransforms
# Render all figure text in Arial.
plt.rcParams.update({"font.family": "Arial"})
# Historical eruption record, presumably one entry per year
# (1 = an eruption recorded in that entry, 0 = none).
data0 = np.array([0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,1,0,0,1,0,0,0,0,
0,1,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,1,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,1,0,0,0,
1,0,0,1,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0])
nchunk = 30
# Eruption count in each consecutive window of `nchunk` entries
# (the final window may be shorter than `nchunk`).
data = np.add.reduceat(data0, np.arange(0, data0.size, nchunk))

# Spacing between successive eruptions: differences of the indices of the 1s.
gaps = np.diff(np.flatnonzero(data0 == 1))

##def fit_function(lamb):
##    '''exponential function, parameter lamb is the fit parameter'''
##    return expon.pdf(lamb)
##parameters, cov_matrix = curve_fit(fit_function, gaps)


### the bins should be of integer width, because poisson is an integer distribution
##bins = np.arange(11) - 0.5
##entries, bin_edges, patches = plt.hist(data, bins=bins, density=True, label='Data')
##
### calculate bin centres
##bin_middles = 0.5 * (bin_edges[1:] + bin_edges[:-1])
##
##
##def fit_function(k, lamb):
##    '''poisson function, parameter lamb is the fit parameter'''
##    return poisson.pmf(k, lamb)
##
##
### fit with curve_fit
##parameters, cov_matrix = curve_fit(fit_function, bin_middles, entries)
##
### plot poisson-deviation with fitted parameter
##x_plot = np.arange(0, 15)
##
##plt.plot(
##    x_plot,
##    fit_function(x_plot, *parameters),
##    marker='o', linestyle='',
##    label='Fit result',
##)
##plt.legend()
##plt.title("Poisson-Fit vs Histogram of Volcanic Eruptions in "+str(nchunk)+"-year intervals")
##
##P = ss.expon.fit(gaps)
##print(P)
##rX = np.linspace(0,40, 100)
##rP = ss.expon.pdf(rX, *P)
##plt.figure()
##plt.hist(gaps,density=True)
##plt.plot(rX, rP)

# Constant shift subtracted from every gap before fitting; genEruptions adds
# it back when it turns sampled gaps into eruption times.
offset = 2.8

# Exponential-mixture fit of the shifted gaps.  The component count is left to
# EMM (presumably selected during fit and exposed as k_final — confirm).
# An earlier experiment used EMM(k=2, n_iter=1).
model = EMM()
pi, mu = model.fit(gaps - offset)

# Model with the final component count; genEruptions only calls its
# generate() method, passing model.pi / model.mu explicitly.
model2 = EMM(k=model.k_final)

# Load the historical series.  Column 0 is used as the year axis and column 3,
# scaled by 0.001, as optical depth in the __main__ plot.
# Passing the path (rather than an un-closed open(...) handle) lets genfromtxt
# open and close the file itself, so no file descriptor is leaked.
data_real = np.genfromtxt("toyKFmodelData7.csv", dtype=float, delimiter=',')
opt_depth = data_real[:, 3] * 0.001
dates = data_real[:, 0]



# Observed eruption strengths; their exponential fit `S` is what genEruptions
# samples eruption strengths from.
strengths = np.array([60.1666666666667,13.6583333333333,142.908333333333,37.1083333333333,39.125,
                    18.2833333333333,71.5166666666667,10.275,24.0833333333333,9.35833333333333,9.825,
                    8.24166666666667,71.6916666666667,34.4,30.125,9.16666666666666,75.2416666666666,
                    13.5666666666667,121.141666666667])
# Maximum-likelihood exponential fit, S = (loc, scale).
S = ss.expon.fit(strengths)

# Evaluation grid and fitted density, kept for the optional diagnostic plot:
#   plt.figure(); plt.hist(strengths, density=True); plt.plot(rX, rS)
rX = np.linspace(0, 150, num=100)
rS = ss.expon(*S).pdf(rX)


##building the model:
def genEruptions(nyears):
    """Simulate a yearly volcanic-forcing series of exactly ``nyears`` values.

    Eruption times come from a renewal process: each gap is ``offset`` plus a
    draw from the fitted exponential mixture (``model2.generate`` with
    ``model.pi`` / ``model.mu``).  Each eruption gets a strength drawn from the
    exponential fit ``S``; strengths below 7 are zeroed.  The strength is put
    in the year after the eruption year, and normally distributed fractions of
    it are added to the eruption year and to the 2nd and 3rd years after it.
    Every year whose value is not positive is replaced by a truncated-normal
    baseline draw.

    Parameters
    ----------
    nyears : int
        Number of years to simulate (converted with ``int``).

    Returns
    -------
    numpy.ndarray
        Float array of length ``nyears``.  The __main__ block scales it by
        0.001 before plotting, as is done for the historical series.
    """
    nyears = int(nyears)
    # Initial eruption count (nyears/4 was guessed empirically).  At least one
    # draw is required so that times[-1] exists for very short horizons.
    numerupt = max(1, int(nyears / 4))
    gentimegaps = np.asarray(model2.generate(numerupt, model.pi, model.mu))
    times = np.cumsum(gentimegaps + offset) - offset
    # The initial guess may not reach the horizon, which previously made the
    # returned series shorter than nyears.  Keep drawing gaps until the last
    # eruption lies at or beyond nyears, so the window is fully simulated.
    while times[-1] < nyears:
        extra = np.asarray(model2.generate(numerupt, model.pi, model.mu))
        gentimegaps = np.concatenate([gentimegaps, extra])
        times = np.cumsum(gentimegaps + offset) - offset
    n_erupt = times.size
    years = times.astype(int)

    # +4 leaves room for the contribution three years after the last eruption.
    lenseries = int(times[-1]) + 4
    strentimes = np.zeros(lenseries)
    sampled = ss.expon.rvs(loc=S[0], scale=S[1], size=n_erupt)
    # NOTE(review): S[0] is the fitted loc (the minimum observed strength,
    # ~8.24), so draws below 7 should not occur and this threshold currently
    # looks inert — confirm the intended cutoff.
    strengthsf = sampled.copy()
    strengthsf[sampled < 7] = 0
    strentimes[years + 1] = strengthsf

    # Fractions of the strength added at year offsets 0, 2 and 3.  Offset 1 is
    # the peak year assigned above; its 0 entries are placeholders and it is
    # skipped, so exactly three normal draws happen, in the order 0, 2, 3.
    meansoffset = [0.523437646, 0, 0.582071463, 0.340446011]
    stdevsoffset = [0.256605137, 0, 0.170816437, 0.17366678]
    for i in (0, 2, 3):
        add = np.random.normal(meansoffset[i], stdevsoffset[i], n_erupt)
        strentimes[years + i] += add * strengthsf

    # NOTE(review): truncnorm's a/b bounds are in units of `scale` relative to
    # `loc`, so a=0 truncates at 3.262 rather than at 0 — confirm intended.
    baseline = ss.truncnorm.rvs(0, 300, loc=3.262, scale=2.663, size=lenseries)
    # Years that are not positive take the baseline value. This covers years
    # with no eruption, sub-threshold eruptions, and negative offset draws.
    strentimes = np.where(strentimes <= 0, baseline, strentimes)
    return strentimes[:nyears]

        
if __name__ == "__main__":
    model.print_result()
    # model.plot_survival_probability()
    print(S)

    # Left panel: historical record; right panel: one simulated realisation.
    fig, (ax0, ax1) = plt.subplots(1, 2, gridspec_kw={'width_ratios': [1, 2.5], 'wspace': 0.4},
                                   figsize=(10, 4))
    ax0.plot(dates, opt_depth)
    ax0.set_title("Historical Volcanic Eruptions")

    # Simulated values are scaled by 0.001, the same factor applied to opt_depth.
    ax1.plot(genEruptions(int(171 * 2.5)) * 0.001)
    ax1.set_title("Simulated Volcanic Eruptions (Sampled)")

    # Panel labels sit 40 pt left of and 1 pt above each axes' top-left corner
    # (dpi_scale_trans is in inches; 1/72 inch = 1 pt).
    trans = mtransforms.ScaledTranslation(-40/72, 1/72, fig.dpi_scale_trans)
    for ax, label in zip((ax0, ax1), ('a)', 'b)')):
        ax.set_xlabel("Year")
        ax.set_ylabel("Optical Depth (unitless) at ~ 525-550nm")
        ax.set_ylim(0, 0.2)
        ax.text(0.0, 1.0, label, transform=ax.transAxes + trans,
                fontsize='large', va='bottom')
    plt.show()
